{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abf96fcf",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp extract_acoustic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0a5f3e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afc34d8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import torch\n",
    "import torchaudio\n",
    "import gc\n",
    "\n",
    "from pathlib import Path\n",
    "from fastcore.script import *\n",
    "from fastprogress import progress_bar, master_bar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc96663a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# unpacked small.tar should go here:\n",
    "datadir = Path('/mnt/')\n",
    "# you can download it from\n",
    "# https://github.com/facebookresearch/libri-light/blob/main/data_preparation/README.md"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a486d4c3",
   "metadata": {},
   "source": [
    "# Acoustic token extraction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99fac6f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def load(fname, newsr=24000):\n",
    "    \"\"\"Read the audio file `fname`, resample it to `newsr` Hz and return the\n",
    "    result as a CUDA tensor with one extra leading dimension.\"\"\"\n",
    "    waveform, orig_sr = torchaudio.load(fname)\n",
    "    resampler = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=newsr)\n",
    "    resampled = resampler(waveform)  # resampling runs before the move to the GPU\n",
    "    return resampled.cuda().unsqueeze(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a6f8857",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def load_model():\n",
    "    \"Create the pretrained 24 kHz EnCodec model (target bandwidth 1.5), on the GPU and in eval mode\"\n",
    "    # imported lazily so that `encodec` is only required when a model is actually built\n",
    "    from encodec.model import EncodecModel\n",
    "    encodec_model = EncodecModel.encodec_model_24khz()\n",
    "    encodec_model.set_target_bandwidth(1.5)\n",
    "    encodec_model = encodec_model.cuda()\n",
    "    encodec_model.eval()\n",
    "    return encodec_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41c2ad6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def extract_Atoks(model, audio):\n",
    "    \"\"\"Encode `audio` (a tensor, or a `Path`/`str` that is first read with `load`)\n",
    "    using `model` (see `load_model`) and return the codes of all chunks\n",
    "    concatenated along the last dimension.\"\"\"\n",
    "    if isinstance(audio, (Path, str)):\n",
    "        audio = load(audio)\n",
    "    # the audio is cut along its last dimension into pieces of at most this many\n",
    "    # samples and every piece is encoded on its own\n",
    "    chunk_len = 320*20000\n",
    "    encoded = []\n",
    "    with torch.no_grad():\n",
    "        for chunk in torch.split(audio, chunk_len, dim=-1):\n",
    "            # keep the first element of the first entry returned by `encode`\n",
    "            encoded.append(model.encode(chunk)[0][0])\n",
    "        frames = torch.cat(encoded, dim=-1)\n",
    "    return frames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0cdaa92",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@call_parse\n",
    "def extract_acoustic(\n",
    "        srcdir:Path,  # source dir, should contain *.flac files\n",
    "        outdir:Path,  # output dir, will get the *.encodec files\n",
    "    ): \n",
    "    \"Convert audio files to .encodec files with tensors of tokens\"\n",
    "    encoder = load_model()\n",
    "    outdir.mkdir(parents=True, exist_ok=True)\n",
    "    # `srcdir` is searched recursively; the list is built up front and handed to the progress bar\n",
    "    flac_files = list(srcdir.rglob('*.flac'))\n",
    "    for flac_path in progress_bar(flac_files):\n",
    "        # only the file name is kept, so every output lands directly in `outdir`\n",
    "        target = outdir/f'{flac_path.stem}.encodec'\n",
    "        codes = extract_Atoks(encoder, flac_path)\n",
    "        torch.save(codes, target)\n",
    "        # drop the reference and run the garbage collector before the next file\n",
    "        del codes\n",
    "        gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd0370b1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
       "        background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      <progress value='131' class='' max='131' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      100.00% [131/131 05:38&lt;00:00]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# process all files for speaker 1401\n",
    "# (`extract_acoustic` takes only the source and output dirs and loads the model itself)\n",
    "extract_acoustic(datadir/'small/1401', datadir/'acoustic-1401')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ffa76be",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "78M\t/mnt/acoustic-1401/\r\n"
     ]
    }
   ],
   "source": [
    "!du -hs {datadir}/acoustic-1401/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5167f334",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import nbdev; nbdev.nbdev_export()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
